{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d56d7546-3664-4ea0-85ca-59add93161f4",
   "metadata": {},
   "source": [
    "## Lesson 4: Applications of Embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7cdc073e-d1ac-49af-97b0-b73fb840d38c",
   "metadata": {},
   "source": [
    "#### Project environment setup\n",
    "\n",
    "- Load credentials and relevant Python Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30dd137e-8522-4e6c-b981-6c5863bc159e",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "from utils import authenticate\n",
    "credentials, PROJECT_ID = authenticate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ce0ce04-d90d-4016-a636-5fc4c5578818",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "REGION = 'us-central1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2a787d9-b471-41dc-89e2-e2e35d2e3a4c",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "import vertexai\n",
    "vertexai.init(project=PROJECT_ID, \n",
    "              location=REGION, \n",
    "              credentials = credentials)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a84c7df8-5404-4697-804c-3e91a500a125",
   "metadata": {},
   "source": [
    "#### Load Stack Overflow questions and answers from BigQuery\n",
    "- BigQuery is Google Cloud's serverless data warehouse.\n",
    "- We'll get the first 500 posts (questions and answers) for each programming language: Python, HTML, R, and CSS."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "737db5e4-a5c4-480c-ab43-ae8e7db75c64",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "from google.cloud import bigquery\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0031e73e-517e-4a3f-b738-f75de03ec6b3",
   "metadata": {
    "height": 403
   },
   "outputs": [],
   "source": [
    "def run_bq_query(sql):\n",
    "\n",
    "    # Create BQ client\n",
    "    bq_client = bigquery.Client(project = PROJECT_ID, \n",
    "                                credentials = credentials)\n",
    "\n",
    "    # Try dry run before executing query to catch any errors\n",
    "    job_config = bigquery.QueryJobConfig(dry_run=True, \n",
    "                                         use_query_cache=False)\n",
    "    bq_client.query(sql, job_config=job_config)\n",
    "\n",
    "    # If dry run succeeds without errors, proceed to run query\n",
    "    job_config = bigquery.QueryJobConfig()\n",
    "    client_result = bq_client.query(sql, \n",
    "                                    job_config=job_config)\n",
    "\n",
    "    job_id = client_result.job_id\n",
    "\n",
    "    # Wait for query/job to finish running. then get & return data frame\n",
    "    df = client_result.result().to_arrow().to_pandas()\n",
    "    print(f\"Finished job_id: {job_id}\")\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1adceeb-c109-43b2-bda9-489e1ab69413",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# define list of programming language tags we want to query\n",
    "\n",
    "language_list = [\"python\", \"html\", \"r\", \"css\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9e0cba0-daaf-4475-99d9-bb4a6ae72874",
   "metadata": {
    "height": 505
   },
   "outputs": [],
   "source": [
    "so_df = pd.DataFrame()\n",
    "\n",
    "for language in language_list:\n",
    "    \n",
    "    print(f\"generating {language} dataframe\")\n",
    "    \n",
    "    query = f\"\"\"\n",
    "    SELECT\n",
    "        CONCAT(q.title, q.body) as input_text,\n",
    "        a.body AS output_text\n",
    "    FROM\n",
    "        `bigquery-public-data.stackoverflow.posts_questions` q\n",
    "    JOIN\n",
    "        `bigquery-public-data.stackoverflow.posts_answers` a\n",
    "    ON\n",
    "        q.accepted_answer_id = a.id\n",
    "    WHERE \n",
    "        q.accepted_answer_id IS NOT NULL AND \n",
    "        REGEXP_CONTAINS(q.tags, \"{language}\") AND\n",
    "        a.creation_date >= \"2020-01-01\"\n",
    "    LIMIT \n",
    "        500\n",
    "    \"\"\"\n",
    "\n",
    "    \n",
    "    language_df = run_bq_query(query)\n",
    "    language_df[\"category\"] = language\n",
    "    so_df = pd.concat([so_df, language_df], \n",
    "                      ignore_index = True) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc1eeb1a-c3da-42bd-986c-2d8beb31eef4",
   "metadata": {},
   "source": [
    "- You can reuse the above code to run your own queries if you are using Google Cloud's BigQuery service.\n",
    "- In this classroom, if you run into any issues, you can load the same data from a csv file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9b70a17-7441-427b-94e3-8c902c01a057",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# Run this cell if you get any errors or you don't want to wait for the query to be completed\n",
    "# so_df = pd.read_csv('so_database_app.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f912a8d-a391-4fda-a7b4-eee4c0d8490d",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "so_df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2107a70c-d6a4-48ef-9798-f6bdbf44aa27",
   "metadata": {},
   "source": [
    "#### Generate text embeddings\n",
    "- To generate embeddings for a dataset of texts, we'll need to group the sentences together in batches and send batches of texts to the model.\n",
    "- The API currently can take batches of up to 5 pieces of text per API call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "678f7521-52ef-40a7-b258-a47a2b14ebc5",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from vertexai.language_models import TextEmbeddingModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b1902c7-6a47-4a43-bf15-6de20ed5d7f7",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "model = TextEmbeddingModel.from_pretrained(\n",
    "    \"textembedding-gecko@001\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b8da7eb-cc45-4621-ac3c-bf08ee2e282a",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d531fad6-cfde-4aa7-9486-f10ab4efa71d",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "# Generator function to yield batches of sentences\n",
    "\n",
    "def generate_batches(sentences, batch_size = 5):\n",
    "    for i in range(0, len(sentences), batch_size):\n",
    "        yield sentences[i : i + batch_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf78aa64-5b40-490f-85a5-f54127684823",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "so_questions = so_df[0:200].input_text.tolist() \n",
    "batches = generate_batches(sentences = so_questions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c1e9d08-5949-4046-987c-ea13b26dad0a",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "batch = next(batches)\n",
    "len(batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70c12ffc-dab7-41fd-93e4-a6baef7ab221",
   "metadata": {},
   "source": [
    "#### Get embeddings on a batch of data\n",
    "- This helper function calls `model.get_embeddings()` on the batch of data, and returns a list containing the embeddings for each text in that batch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe933993-1cdf-460b-afad-0173b2282e8d",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "def encode_texts_to_embeddings(sentences):\n",
    "    try:\n",
    "        embeddings = model.get_embeddings(sentences)\n",
    "        return [embedding.values for embedding in embeddings]\n",
    "    except Exception:\n",
    "        return [None for _ in range(len(sentences))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9190606-9d73-4455-ab4b-97df20f0a2fb",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "batch_embeddings = encode_texts_to_embeddings(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "259f5f5d-0c6b-4030-9a45-92438625b792",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "f\"{len(batch_embeddings)} embeddings of size \\\n",
    "{len(batch_embeddings[0])}\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "debb7234-eccd-4dc8-bb4b-abd7a21c7714",
   "metadata": {},
   "source": [
    "#### Code for getting data on an entire data set\n",
    "- Most API services have rate limits, so we've provided a helper function (in utils.py) that you could use to wait in-between API calls.\n",
    "- If the code was not designed to wait in-between API calls, you may not receive embeddings for all batches of text.\n",
    "- This particular service can handle 20 calls per minute.  In calls per second, that's 20 calls divided by 60 seconds, or `20/60`.\n",
    "\n",
    "```Python\n",
    "from utils import encode_text_to_embedding_batched\n",
    "\n",
    "so_questions = so_df.input_text.tolist()\n",
    "question_embeddings = encode_text_to_embedding_batched(\n",
    "                            sentences=so_questions,\n",
    "                            api_calls_per_second = 20/60, \n",
    "                            batch_size = 5)\n",
    "```\n",
    "\n",
    "In order to handle limits of this classroom environment, we're not going to run this code to embed all of the data. But you can adapt this code for your own projects and datasets."
   ]
  },
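  {
   "cell_type": "markdown",
   "id": "sketch-rate-limited-batching",
   "metadata": {},
   "source": [
    "The helper in utils.py is not shown in this notebook. As a rough illustration of the idea only, the sketch below spaces out calls so that no more than `calls_per_second` batches are sent. The name `embed_in_batches` and its parameters are hypothetical, and `embed_fn` stands for any function that maps a list of texts to a list of embeddings, such as `encode_texts_to_embeddings` above.\n",
    "\n",
    "```Python\n",
    "import time\n",
    "\n",
    "def embed_in_batches(sentences, embed_fn, batch_size=5,\n",
    "                     calls_per_second=20/60):\n",
    "    # Hypothetical sketch, not the utils.py implementation\n",
    "    min_interval = 1.0 / calls_per_second  # seconds per call (3 s by default)\n",
    "    results = []\n",
    "    for i in range(0, len(sentences), batch_size):\n",
    "        start = time.monotonic()\n",
    "        results.extend(embed_fn(sentences[i : i + batch_size]))\n",
    "        # Wait out the rest of the interval, except after the last batch\n",
    "        if i + batch_size < len(sentences):\n",
    "            elapsed = time.monotonic() - start\n",
    "            time.sleep(max(0.0, min_interval - elapsed))\n",
    "    return results\n",
    "```\n",
    "\n",
    "If `embed_fn` behaves like `encode_texts_to_embeddings`, failed batches come back as `None` entries, so it may be worth filtering or retrying those before building an array."
   ]
  },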
  {
   "cell_type": "markdown",
   "id": "fdd85d39-72d8-432d-9210-55e14cac508c",
   "metadata": {},
   "source": [
    "#### Load the data from file\n",
    "- We'll load the stack overflow questions, answers, and category labels (Python, HTML, R, CSS) from a .csv file.\n",
    "- We'll load the embeddings of the questions (which we've precomputed with batched calls to `model.get_embeddings()`), from a pickle file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87c78630-2cda-487a-be24-5ade753d76f9",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "so_df = pd.read_csv('so_database_app.csv')\n",
    "so_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a413056-0ec7-4137-a191-814e4797e0a2",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7836318-0148-456d-b29e-4a26aa4f8221",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "with open('question_embeddings_app.pkl', 'rb') as file:\n",
    "    question_embeddings = pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a4134a8-76d3-4677-b180-f8477b1a2c15",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "print(\"Shape: \" + str(question_embeddings.shape))\n",
    "print(question_embeddings)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3061eb7d-b69c-4eae-b583-67e9d7ca1f47",
   "metadata": {},
   "source": [
    "#### Cluster the embeddings of the Stack Overflow questions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a58b189d-61f7-4c68-8c73-769c59a23987",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "from sklearn.cluster import KMeans\n",
    "from sklearn.decomposition import PCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd2cd054-426e-4d92-9561-285a0b55792e",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "clustering_dataset = question_embeddings[:1000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05421596-894c-4d53-9551-61839e72c0eb",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "n_clusters = 2\n",
    "kmeans = KMeans(n_clusters=n_clusters, \n",
    "                random_state=0, \n",
    "                n_init = 'auto').fit(clustering_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fac5a79-7639-4072-8c9d-86dc2e2c6658",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "kmeans_labels = kmeans.labels_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "210b0fbe-a5df-488f-9814-60f4304804d0",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "PCA_model = PCA(n_components=2)\n",
    "PCA_model.fit(clustering_dataset)\n",
    "new_values = PCA_model.transform(clustering_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d2095e0-605f-4dd3-9786-2f313ff5c60a",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import mplcursors\n",
    "%matplotlib ipympl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "961b770b-53ea-4b49-bfc8-507cd711fc99",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "from utils import clusters_2D\n",
    "clusters_2D(x_values = new_values[:,0], y_values = new_values[:,1], \n",
    "            labels = so_df[:1000], kmeans_labels = kmeans_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b499cf9d-c739-4a34-b38d-e270fffe14f7",
   "metadata": {},
   "source": [
    "- Clustering is able to identify two distinct clusters of HTML or Python related questions, without being given the category labels (HTML or Python)."
   ]
  },
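  {
   "cell_type": "markdown",
   "id": "sketch-cluster-vs-category",
   "metadata": {},
   "source": [
    "One way to check this numerically, rather than by eye on the plot, is to cross-tabulate the cluster ids against the known categories. The sketch below is a suggestion and not part of the course code; `cluster_vs_category` is a hypothetical helper name and the toy inputs are made up.\n",
    "\n",
    "```Python\n",
    "import pandas as pd\n",
    "\n",
    "def cluster_vs_category(categories, cluster_ids):\n",
    "    # Rows: known category; columns: k-means cluster id; cells: row counts\n",
    "    return pd.crosstab(pd.Series(list(categories), name='category'),\n",
    "                       pd.Series(list(cluster_ids), name='cluster'))\n",
    "\n",
    "# Toy example with made-up labels\n",
    "print(cluster_vs_category(['python', 'python', 'html', 'html'],\n",
    "                          [0, 0, 1, 0]))\n",
    "```\n",
    "\n",
    "In this notebook you could call `cluster_vs_category(so_df['category'][:1000], kmeans_labels)`. A table with most counts concentrated in one cell per row would support the observation above."
   ]
  },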
  {
   "cell_type": "markdown",
   "id": "c46c217f-d121-4a14-a351-340a134db73f",
   "metadata": {},
   "source": [
    "## Anomaly / Outlier detection\n",
    "\n",
    "- We can add an anomalous piece of text and check if the outlier (anomaly) detection algorithm (Isolation Forest) can identify it as an outlier (anomaly), based on its embedding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f7e3a9e-ec14-4f85-9eef-5766f85e0d0c",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from sklearn.ensemble import IsolationForest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96d872a2-5fbb-466b-ac80-84a1c6458006",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "input_text = \"\"\"I am making cookies but don't \n",
    "                remember the correct ingredient proportions. \n",
    "                I have been unable to find \n",
    "                anything on the web.\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffb3b6f7-8cc8-4f3b-a19e-541956903e9e",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "emb = model.get_embeddings([input_text])[0].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b89da1d8-d17f-4ac7-a774-7fc0f67fba00",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "embeddings_l = question_embeddings.tolist()\n",
    "embeddings_l.append(emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f98eb7a-9a62-4d08-8744-5618cdf6d241",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "embeddings_array = np.array(embeddings_l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8a6eaa7-1067-4414-9e40-05352eb5ec03",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "print(\"Shape: \" + str(embeddings_array.shape))\n",
    "print(embeddings_array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11467308-b48a-4341-a918-83940239762f",
   "metadata": {
    "height": 131
   },
   "outputs": [],
   "source": [
    "# Add the outlier text to the end of the stack overflow dataframe\n",
    "so_df = pd.read_csv('so_database_app.csv')\n",
    "new_row = pd.Series([input_text, None, \"baking\"], \n",
    "                    index=so_df.columns)\n",
    "so_df.loc[len(so_df)+1] = new_row\n",
    "so_df.tail()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da8b00f6-bc36-4e07-87af-043609e273c4",
   "metadata": {},
   "source": [
    "#### Use Isolation Forest to identify potential outliers\n",
    "\n",
    "- `IsolationForest` classifier will predict `-1` for potential outliers, and `1` for non-outliers.\n",
    "- You can inspect the rows that were predicted to be potential outliers and verify that the question about baking is predicted to be an outlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36815aaf-fa2a-4e0f-98d0-07c334a79887",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "clf = IsolationForest(contamination=0.005, \n",
    "                      random_state = 2) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5781dcd9-cdce-4307-ba49-dccf9cd7646e",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "preds = clf.fit_predict(embeddings_array)\n",
    "\n",
    "print(f\"{len(preds)} predictions. Set of possible values: {set(preds)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "483e753a-5291-423b-a7a3-c36af012d3f4",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "so_df.loc[preds == -1]"
   ]
  },
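  {
   "cell_type": "markdown",
   "id": "sketch-isolation-forest-scores",
   "metadata": {},
   "source": [
    "`fit_predict` only gives a yes/no label per row. To rank rows by how anomalous they look, `IsolationForest` also provides `score_samples`, where lower scores mean more anomalous. Below is a minimal sketch on synthetic data; the points are made up and only stand in for embeddings.\n",
    "\n",
    "```Python\n",
    "import numpy as np\n",
    "from sklearn.ensemble import IsolationForest\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "# Made-up stand-in for embeddings: 200 points near the origin, one far away\n",
    "points = np.vstack([rng.normal(0, 1, size=(200, 8)),\n",
    "                    np.full((1, 8), 10.0)])\n",
    "\n",
    "forest = IsolationForest(random_state=0).fit(points)\n",
    "scores = forest.score_samples(points)    # lower = more anomalous\n",
    "most_anomalous = np.argsort(scores)[:5]  # indices of the 5 lowest scores\n",
    "print(most_anomalous)\n",
    "```\n",
    "\n",
    "Applied to `embeddings_array`, the same two calls would let you inspect the lowest-scoring questions without committing to a `contamination` threshold first."
   ]
  },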
  {
   "cell_type": "markdown",
   "id": "4237f497-489e-448c-bf52-781a31d28558",
   "metadata": {},
   "source": [
    "#### Remove the outlier about baking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8da293f-7945-4686-93fb-0a62269fed07",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "so_df = so_df.drop(so_df.index[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "869a7b82-a813-474d-98d3-d740a2e8e6c5",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "so_df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a14167fd-21f5-4929-9ec3-ff85a9a4de9e",
   "metadata": {},
   "source": [
    "## Classification\n",
    "- Train a random forest model to classify the category of a Stack Overflow question (as either Python, R, HTML or CSS)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "319bf0c2-92d1-403b-acd9-919a3e21033f",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dd46eb2-445a-4e1f-97a0-8c995bfaf892",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6a79406-57bb-4480-9184-1ba9bebd5c36",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "# re-load the dataset from file\n",
    "so_df = pd.read_csv('so_database_app.csv')\n",
    "X = question_embeddings\n",
    "X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10826eca-8fa0-4ce7-9b93-aa3e8fb19e2e",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "y = so_df['category'].values\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6a3ac12-b97a-4345-aea4-498de0babb38",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(X, \n",
    "                                                    y, \n",
    "                                                    test_size = 0.2, \n",
    "                                                    random_state = 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b65f8db9-7c39-42bd-ac01-8217688454ec",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "clf = RandomForestClassifier(n_estimators=200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e6a0d25-9708-41b8-a754-bdccf5b94b66",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "clf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d6eda9b-25f5-4cbd-8ab1-462361f48a21",
   "metadata": {},
   "source": [
    "#### You can check the predictions on a few questions from the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e134eb4-b3cd-425a-81dd-e2a96362c782",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "y_pred = clf.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0805c099-1349-45a0-b1bc-47b64b15495a",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "accuracy = accuracy_score(y_test, y_pred) # compute accuracy\n",
    "print(\"Accuracy:\", accuracy)"
   ]
  },
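  {
   "cell_type": "markdown",
   "id": "sketch-per-class-metrics",
   "metadata": {},
   "source": [
    "Overall accuracy can hide which categories get confused with each other, so a per-class view may help. The sketch below uses scikit-learn's `confusion_matrix` and `classification_report` on made-up labels; in this notebook you would pass `y_test` and `y_pred` instead.\n",
    "\n",
    "```Python\n",
    "from sklearn.metrics import classification_report, confusion_matrix\n",
    "\n",
    "# Made-up labels, for illustration only\n",
    "y_true = ['python', 'html', 'r', 'css', 'css', 'html']\n",
    "y_hat = ['python', 'html', 'r', 'html', 'css', 'html']\n",
    "labels = ['python', 'html', 'r', 'css']\n",
    "\n",
    "# Rows are true categories, columns are predicted categories\n",
    "print(confusion_matrix(y_true, y_hat, labels=labels))\n",
    "print(classification_report(y_true, y_hat, labels=labels,\n",
    "                            zero_division=0))\n",
    "```\n",
    "\n",
    "Passing `labels` explicitly fixes the row and column order, which makes the matrix easier to read."
   ]
  },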
  {
   "cell_type": "markdown",
   "id": "6ae498a1-26ee-4111-90ba-b625804353de",
   "metadata": {},
   "source": [
    "#### Try out the classifier on some questions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1641137e-f556-443b-b67a-7468229ce7ee",
   "metadata": {
    "height": 250
   },
   "outputs": [],
   "source": [
    "# choose a number between 0 and 1999\n",
    "i = 2\n",
    "label = so_df.loc[i,'category']\n",
    "question = so_df.loc[i,'input_text']\n",
    "\n",
    "# get the embedding of this question and predict its category\n",
    "question_embedding = model.get_embeddings([question])[0].values\n",
    "pred = clf.predict([question_embedding])\n",
    "\n",
    "print(f\"For question {i}, the prediction is `{pred[0]}`\")\n",
    "print(f\"The actual label is `{label}`\")\n",
    "print(\"The question text is:\")\n",
    "print(\"-\"*50)\n",
    "print(question)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
